import os
import sys
import shutil
import argparse
import numpy as np
import multiprocessing as mp
from paddle import fluid
# import paddle.distributed.fleet as fleet
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.base import role_maker
from tensorboardX import SummaryWriter

sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), '..')))
from config import XiaoduHiConfig
from interaction.attention_ctrl import AttentionController
from interaction.common.data import XiaoduHiDataloaderv2
from interaction.common.data_via_decord import XiaoduHiDecordLoader

# Directory holding pretrained MobileNetV2 variables. main() loads the
# variables that have a matching file in this directory when the inputs type
# starts with 'inst_crop' (except 'inst_crop_wo_crop').
MobileNetV2_Pretrained = 'pretrain_models/MobileNetV2_pretrained'

# Global step counters used as the step argument of TensorBoard scalars.
# train_epoch() advances _update_step and eval_model() advances _eval_step;
# main() persists both to 'tb_state.txt' and restores them when resuming.
_update_step = 0
_eval_step = 0


def parse_args():
    """Define and parse the command-line options of this training script.

    Options are organised into four argparse groups: 'data', 'model',
    'train' and 'decord'.

    Returns:
        argparse.Namespace: The parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description='Train attention controller.')

    # Dataset locations and the kind of inputs fed to the controller.
    data_group = parser.add_argument_group('data')
    data_group.add_argument(
        '--video_tracking_dir', '-vd', type=str, default='data/clips',
        help='Directory of preprocessed video obj-tracking results. '
        'It is generated by scripts/collect_v2_data.py')
    data_group.add_argument(
        '--anno_dir', '-ad', type=str, default='data/annos',
        help='Directory of V2 annotations.')
    data_group.add_argument(
        '--wae_dir', '-wd', type=str, default='data/raw_wae',
        help='Directory of multimodal action embeddings, generated by '
        'scripts/collect_act_emb.py.')
    data_group.add_argument(
        '--train_dataset', '-tr', type=str, default='data/train.pkl',
        help='Path to training dataset which is generated by '
        'scripts/prepare_dataset.py -dv ds.')
    data_group.add_argument(
        '--test_dataset', '-te', type=str, default='data/test.pkl',
        help='Path to test dataset which is generated by '
        'scripts/prepare_dataset.py -dv ds.')
    data_group.add_argument(
        '--full_neg_train', type=str, default=None,
        help='Path to text file recording paths of full '
        'negative frames for training.')
    data_group.add_argument(
        '--full_neg_test', type=str, default=None,
        help='Path to text file recording paths of full '
        'negative frames for testing.')
    data_group.add_argument(
        '--inputs_type', type=str, default='visual_token',
        help='Inputs type of attention controller. It can be "visual_token", '
        '"instance", "without_inst_fm", "without_inst_cls", or '
        '"without_inst_pos". This is used for ablation study.')

    # Model architecture and initial weights.
    model_group = parser.add_argument_group('model')
    model_group.add_argument(
        '--yolov4_model_dir', type=str,
        default='tools/yolov4_paddle/inference_model',
        help='Directory of YOLOv4 model.')
    model_group.add_argument(
        '--init_params', type=str, default=None,
        help='The init parameter weights, useful for resume training or '
        'load init perception model weights.')
    model_group.add_argument(
        '--num_decoder_blocks', '-nb', type=int, default=6,
        help='Number of transformer decoder blocks of the attention model.')
    model_group.add_argument(
        '--num_heads', '-nh', type=int, default=8,
        help='Number of heads for multihead attention.')
    model_group.add_argument(
        '--model_dim', '-md', type=int, default=512,
        help='Dimension of key, query and value in self-attention.')
    model_group.add_argument(
        '--ffn_dim', '-fd', type=int, default=2048,
        help='Dimension of feedforward network after self-attention.')
    model_group.add_argument(
        '--normalize_before', default=False, action='store_true',
        help='Whether to use layer normalization at the beginning of '
        'each transformer decoder block.')
    model_group.add_argument(
        '--frame_emb_trainable', default=False, action='store_true',
        help='Whether frame embedding is trainable or use 1-D sine '
        'postional embedding as original transformer paper.')
    model_group.add_argument(
        '--use_last_act_loss', default=False, action='store_true',
        help='Whether to compute action loss only for last frame.')

    # Optimisation, loss weighting, device placement and output directory.
    train_group = parser.add_argument_group('train')
    train_group.add_argument(
        '--distributed_training', default=False, action='store_true',
        help='Whether to use fleet distributed training.')
    train_group.add_argument(
        '--trigger_loss_coef', type=float, default=5.0,
        help='Coefficent of trigger loss.')
    train_group.add_argument(
        '--obj_loss_coef', type=float, default=1.0,
        help='Coefficent of interaction object loss.')
    train_group.add_argument(
        '--act_loss_coef', type=float, default=1.0,
        help='Coefficent of multimodal action NNL loss.')
    train_group.add_argument(
        '--data_workers_for_train', type=int, default=8,
        help='Number of data workers for training dataloader.')
    train_group.add_argument(
        '--data_workers_for_test', type=int, default=8,
        help='Number of data workers for test dataloader.')
    train_group.add_argument(
        '--epochs', type=int, default=10,
        help='The number of epochs to train the core controller.')
    train_group.add_argument(
        '--run_eval_after_epochs', type=int, default=1,
        help='After training this given epochs, run eval.')
    train_group.add_argument(
        '--dropout', type=float, default=0.0,
        help='Probability of dropout, default 0.0, meaning no dropout.')
    train_group.add_argument(
        '--lr', type=float, default=0.0001, help='Learning rate.')
    train_group.add_argument(
        '--bs', type=int, default=8, help='Batch size.')
    train_group.add_argument(
        '--gpu', type=int, default=0, help='GPU card to train model.')
    train_group.add_argument(
        '--data_worker_gpus_for_train', type=str, default='1',
        help='GPU cards for training data workers, separated by comma.')
    train_group.add_argument(
        '--data_worker_gpus_for_test', type=str, default='2',
        help='GPU cards for test data workers, separated by comma.')
    train_group.add_argument(
        '--save', type=str, default='save',
        help='Directory to save parameters and log files.')

    # Options that only apply when the decord-based loader is selected.
    decord_group = parser.add_argument_group('decord')
    decord_group.add_argument(
        '--use_decord', default=False, action='store_true',
        help='Whether to use decord reader.')
    decord_group.add_argument(
        '--decord_readers', type=int, default=8,
        help='Number of parallel decord readers to use.')
    decord_group.add_argument(
        '--decord_detectors', type=int, default=2,
        help='Number of yolov4 detectors to use.')
    decord_group.add_argument(
        '--decord_post_workers', type=int, default=4,
        help='Number of post workers to use.')
    decord_group.add_argument(
        '--decord_ds_pkl', type=str, default='data/decord.pkl',
        help='Path to decord dataset pkl, generated by '
        'scripts/prepare_dataset.py -vt v2_decord')

    return parser.parse_args()


def convert_gpu_ids(ids):
    """Turn a comma-separated string of GPU ids into a list of ints.

    Args:
        ids: String such as ``'0,1,2'``.

    Returns:
        list: One ``int`` per comma-separated field, in the given order.

    Raises:
        ValueError: If any field is not a valid integer literal.
    """
    fields = ids.split(',')
    return list(map(int, fields))


def train_epoch(exe, program, attention_ctrl, train_dataloader,
                log_steps=1, log_file=None, tb_writer=None,
                worker_index=None):
    """Run the training program once over every batch of the dataloader.

    For each batch the four losses exposed by ``attention_ctrl`` are fetched
    and the module-level ``_update_step`` counter is incremented.

    Args:
        exe: Executor whose ``run`` method executes ``program``.
        program: Program passed to ``exe.run``.
        attention_ctrl: Object exposing ``loss``, ``trigger_loss``,
            ``obj_loss`` and ``act_loss`` used as the fetch list.
        train_dataloader: Callable returning an iterable of feed data.
        log_steps: Losses are printed whenever ``_update_step`` is a
            multiple of this value.
        log_file: Optional path; one comma-separated line of losses is
            appended per batch.
        tb_writer: Optional writer receiving the losses on every batch.
        worker_index: Optional worker index. When given, console and file
            output are prefixed with it and scalars are written through
            ``add_scalars`` under the key ``'worker_<index>'``.
    """
    global _update_step

    is_worker = worker_index is not None
    prefix = '[worker index: {}] '.format(worker_index) if is_worker else ''

    for batch_id, data in enumerate(train_dataloader()):
        print('{}Train batch {}'.format(prefix, batch_id))

        fetch_targets = [attention_ctrl.loss,
                         attention_ctrl.trigger_loss,
                         attention_ctrl.obj_loss,
                         attention_ctrl.act_loss]
        total_loss, trigger_loss, obj_loss, act_loss = exe.run(
            program, feed=data, fetch_list=fetch_targets)

        _update_step += 1

        # Console summary, throttled by log_steps.
        if _update_step % log_steps == 0:
            summary = ('Total loss: {:.4f}, Trigger loss: {:.4f}'
                       ', Obj loss: {:.4f}, Act loss: {:.4f}').format(
                total_loss[0], trigger_loss[0], obj_loss[0], act_loss[0])
            print(prefix + summary)

        # TensorBoard scalars, one tag per loss, written in a fixed order.
        if tb_writer is not None:
            tagged = (('total_loss', total_loss),
                      ('trigger_loss', trigger_loss),
                      ('obj_loss', obj_loss),
                      ('act_loss', act_loss))
            for tag, fetched in tagged:
                if is_worker:
                    tb_writer.add_scalars(
                        tag,
                        {'worker_{}'.format(worker_index): fetched[0]},
                        _update_step)
                else:
                    tb_writer.add_scalar(tag, fetched[0], _update_step)

        # Plain-text loss log, one appended line per batch.
        if log_file is not None:
            with open(log_file, 'a') as f:
                row = '{:.4f},{:.4f},{:.4f},{:.4f}\n'.format(
                    total_loss[0], trigger_loss[0], obj_loss[0],
                    act_loss[0])
                if is_worker:
                    row = '{} '.format(worker_index) + row
                f.write(row)


def parse_trigger_pred(pred, true, tolerant_level):
    """Count trigger true positives, false positives and false negatives.

    ``pred`` and ``true`` are walked in lockstep. For each pair, the first
    index where the entry equals 1 is taken as the trigger position.

    * both have a trigger and the positions differ by at most
      ``tolerant_level``: one true positive;
    * both have a trigger but they are further apart, or only the
      prediction has one: one false positive;
    * only the ground truth has a trigger: one false negative;
    * neither has a trigger: nothing is counted.

    Args:
        pred: Iterable of arrays of predicted labels.
        true: Iterable of arrays of ground-truth labels.
        tolerant_level: Maximum allowed distance between the first
            predicted and first true trigger positions.

    Returns:
        tuple: ``(tp, fp, fn)`` counts.
    """
    n_tp, n_fp, n_fn = 0, 0, 0

    for pred_row, true_row in zip(pred, true):
        pred_hits = np.where(pred_row == 1)[0]
        true_hits = np.where(true_row == 1)[0]
        has_pred = len(pred_hits) > 0
        has_true = len(true_hits) > 0

        if not has_pred:
            # Nothing predicted: a miss only if there was something to find.
            if has_true:
                n_fn += 1
            continue

        matched = has_true and \
            abs(pred_hits[0] - true_hits[0]) <= tolerant_level
        if matched:
            n_tp += 1
        else:
            n_fp += 1

    return n_tp, n_fp, n_fn


def calculate_avg_precison(precisions):
    """Average the interpolated precisions over all thresholds.

    The input is ordered by increasing threshold. It is first reversed, and
    each entry is then replaced by the maximum of itself and every entry
    after it in the reversed order. The mean of those values is returned.

    Args:
        precisions: Sequence of precision values, increasing threshold.

    Returns:
        The mean interpolated precision.

    Raises:
        ZeroDivisionError: If ``precisions`` is empty.
    """
    ordered = list(reversed(precisions))
    envelope = [max(ordered[start:]) for start in range(len(ordered))]
    return sum(envelope) / len(envelope)


def calculate_avg_recall(recalls):
    """Average the interpolated recalls over all thresholds.

    The input is ordered by increasing threshold. Each entry is replaced by
    the maximum of itself and every later entry, and the mean of those
    values is returned.

    Args:
        recalls: Iterable of recall values, increasing threshold.

    Returns:
        The mean interpolated recall.

    Raises:
        ZeroDivisionError: If ``recalls`` is empty.
    """
    ordered = list(recalls)
    envelope = [max(ordered[start:]) for start in range(len(ordered))]
    return sum(envelope) / len(envelope)


def eval_model(exe, program, preds, act_nll_loss, test_dataloader,
               epoch_id, log_file=None, tb_writer=None, th_step=0.05,
               tolerant_levels=2, worker_index=None):
    """Evaluate trigger prediction over the test dataloader and report it.

    For every batch, ``preds[0]`` and ``act_nll_loss`` are fetched, the
    module-level ``_eval_step`` counter is incremented and the action loss
    is optionally written to TensorBoard. The trigger prediction is
    binarised at each threshold and compared with ``data[0]['has_act']`` via
    ``parse_trigger_pred`` for every tolerance level from 0 to
    ``tolerant_levels`` inclusive. Precision, recall and their interpolated
    averages are then printed and optionally appended to ``log_file``.

    Args:
        exe: Executor whose ``run`` method executes ``program``.
        program: Program passed to ``exe.run``.
        preds: Sequence whose first element is fetched as the trigger
            prediction.
        act_nll_loss: Fetch target logged as ``'eval_act_loss'``.
        test_dataloader: Callable returning an iterable of feed data; each
            item is indexed as ``data[0]['has_act']``.
        epoch_id: Epoch number shown in the report header, or ``None`` for
            a header without a number.
        log_file: Optional path the report is appended to.
        tb_writer: Optional writer receiving ``'eval_act_loss'`` per batch.
        th_step: Increment between consecutive thresholds; the first
            threshold is always 0.05 and thresholds stay below 1.0.
        tolerant_levels: Highest tolerance level to evaluate.
        worker_index: Optional worker index used to prefix console output
            and to key the TensorBoard scalars.
    """
    global _eval_step

    # Thresholds are produced by repeated addition; the resulting floats are
    # also used as dictionary keys below.
    thresholds = []
    cursor = 0.05
    while cursor < 1.0:
        thresholds.append(cursor)
        cursor += th_step

    levels = list(range(tolerant_levels + 1))

    def _zero_table():
        # One counter per (tolerance level, threshold) pair.
        return {lv: {th: 0 for th in thresholds} for lv in levels}

    tp, fp, fn = _zero_table(), _zero_table(), _zero_table()

    for batch_id, data in enumerate(test_dataloader()):
        if worker_index is not None:
            print('[worker index: {}] Eval batch {}'.format(
                worker_index, batch_id))
        else:
            print('Eval batch {}'.format(batch_id))

        true_has_act = np.array(data[0]['has_act'])
        trigger_pred, act_loss = exe.run(
            program, feed=data, fetch_list=[preds[0], act_nll_loss])
        _eval_step += 1

        if tb_writer is not None:
            if worker_index is not None:
                tb_writer.add_scalars(
                    'eval_act_loss',
                    {'worker_{}'.format(worker_index): act_loss[0]},
                    _eval_step)
            else:
                tb_writer.add_scalar('eval_act_loss', act_loss[0], _eval_step)

        for th in thresholds:
            pred_has_act = (trigger_pred > th).astype(np.int64)
            for lv in levels:
                d_tp, d_fp, d_fn = parse_trigger_pred(
                    pred_has_act, true_has_act, lv)
                tp[lv][th] += d_tp
                fp[lv][th] += d_fp
                fn[lv][th] += d_fn

    # eps keeps both ratios defined when a denominator would be zero.
    eps = 1e-8
    precision = {
        lv: {th: (tp[lv][th] + eps) / (tp[lv][th] + fp[lv][th] + eps)
             for th in thresholds}
        for lv in levels}
    recall = {
        lv: {th: (tp[lv][th] + eps) / (tp[lv][th] + fn[lv][th] + eps)
             for th in thresholds}
        for lv in levels}

    if epoch_id is None:
        pieces = ['--------------- Evaluate ---------------\n']
    else:
        pieces = ['--------------- Evaluate %d ---------------\n' % epoch_id]

    for lv in levels:
        pieces.append('----- Tolerant-%d -----\n' % lv)
        pieces.extend(
            'Threshold: {:.2f}, Precision: {}, Recall: {}\n'.format(
                th, precision[lv][th], recall[lv][th])
            for th in thresholds)
        pieces.append('AP{:.2f}:{:.2f}:{:.2f}: {}\n'.format(
            thresholds[0], thresholds[-1], th_step,
            calculate_avg_precison([precision[lv][th] for th in thresholds])))
        pieces.append('AR{:.2f}:{:.2f}:{:.2f}: {}\n'.format(
            thresholds[0], thresholds[-1], th_step,
            calculate_avg_recall([recall[lv][th] for th in thresholds])))

    out = ''.join(pieces)
    print(out)
    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(out)


def main(args):
    """Build the attention controller, then train and evaluate it.

    The function builds the train/test programs, optionally restores
    weights and step counters from ``args.init_params``, creates the train
    and test dataloaders, and runs the epoch loop. After every epoch the
    parameters and the TensorBoard step counters are saved under
    ``<args.save>/epoch_<id>``. Evaluation runs after every epoch whose id is
    positive and a multiple of ``args.run_eval_after_epochs``, and once more
    at the end if the last trained epoch was not evaluated.

    Args:
        args: Namespace returned by ``parse_args()``.

    Raises:
        ValueError: If ``args.run_eval_after_epochs`` is 0, which would
            otherwise fail with a ZeroDivisionError only after training
            has already started.
    """
    global _update_step
    global _eval_step

    if args.run_eval_after_epochs == 0:
        raise ValueError('--run_eval_after_epochs must not be 0.')

    cfg = XiaoduHiConfig()
    cfg.scene_sensor_algo = 'yolov4'

    wae_ndarray = np.load(os.path.join(args.wae_dir, 'raw_wae.npy'))
    start_epoch = 0

    train_program = fluid.Program()
    startup_program = fluid.Program()

    with fluid.program_guard(train_program, startup_program):
        attention_ctrl = AttentionController(
            inputs_type=args.inputs_type,
            num_actions=wae_ndarray.shape[0],
            act_tr_dim=wae_ndarray.shape[1],
            act_emb_ndarray=wae_ndarray,
            num_frames=cfg.ob_window_len,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            model_dim=args.model_dim,
            num_decoder_blocks=args.num_decoder_blocks,
            num_heads=args.num_heads,
            ffn_dim=args.ffn_dim,
            dropout=args.dropout,
            normalize_before=args.normalize_before,
            frame_emb_trainable=args.frame_emb_trainable,
            trigger_loss_coef=args.trigger_loss_coef,
            obj_loss_coef=args.obj_loss_coef,
            act_loss_coef=args.act_loss_coef,
            use_last_act_loss=args.use_last_act_loss,
            mode='train')
        preds = attention_ctrl.predict()
        # Clone before minimize() so the test program is taken from the
        # graph as it is before the optimizer is applied.
        test_program = train_program.clone(for_test=True)

        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=args.lr,
            regularization=fluid.regularizer.L2Decay(
                regularization_coeff=0.1))
        if args.distributed_training:
            # NOTE(review): Paddle's fleet examples usually call fleet.init()
            # before fleet.distributed_optimizer(); the original order is
            # kept here -- confirm against the installed Paddle version.
            optimizer = fleet.distributed_optimizer(optimizer)
            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
            fleet.init(role)

        optimizer.minimize(attention_ctrl.loss)

    if args.distributed_training:
        place = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))
    else:
        place = fluid.CUDAPlace(args.gpu)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    if args.inputs_type.startswith('inst_crop') and \
       args.inputs_type != 'inst_crop_wo_crop':
        fluid.io.load_vars(
            exe, MobileNetV2_Pretrained,
            main_program=train_program,
            predicate=lambda v: os.path.exists(
                os.path.join(MobileNetV2_Pretrained, v.name)))
        print('Loaded weights from {}'.format(MobileNetV2_Pretrained))

    if args.init_params is not None:
        # normpath drops a trailing separator, which would otherwise make
        # basename() return '' and silently lose the resume epoch.
        base = os.path.basename(os.path.normpath(args.init_params))
        if base.startswith('epoch_'):
            start_epoch = int(base[len('epoch_'):]) + 1

        # Restore the TensorBoard step counters saved next to the weights.
        tb_state = os.path.join(args.init_params, 'tb_state.txt')
        if os.path.exists(tb_state):
            with open(tb_state, 'r') as f:
                update_step, eval_step = f.readline().split(' ')
                _update_step = int(update_step)
                _eval_step = int(eval_step)

        fluid.io.load_vars(
            exe, args.init_params,
            main_program=train_program,
            predicate=lambda v: os.path.exists(
                os.path.join(args.init_params, v.name)))
        print('Loaded weights from {}'.format(args.init_params))

    if args.distributed_training:
        train_worker_gpus = [int(os.environ.get('FLAGS_selected_gpus', 0))]
        test_worker_gpus = train_worker_gpus
    else:
        train_worker_gpus = convert_gpu_ids(args.data_worker_gpus_for_train)
        test_worker_gpus = convert_gpu_ids(args.data_worker_gpus_for_test)

    # exist_ok avoids a check-then-create race when several distributed
    # workers start with the same --save directory.
    os.makedirs(args.save, exist_ok=True)

    if not args.use_decord:
        train_dataloader = XiaoduHiDataloaderv2(
            attention_ctrl.feed_list, [place], args.yolov4_model_dir,
            args.video_tracking_dir, args.train_dataset,
            full_neg_txt=args.full_neg_train,
            batch_size=args.bs,
            num_workers=args.data_workers_for_train,
            worker_gpus=train_worker_gpus,
            roi_feat_resolution=cfg.roi_feat_resolution,
            ob_window_len=cfg.ob_window_len,
            interval=cfg.interval,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            augment=False,
            resample_negs_per_epoch=True)
        test_dataloader = XiaoduHiDataloaderv2(
            attention_ctrl.feed_list, [place], args.yolov4_model_dir,
            args.video_tracking_dir, args.test_dataset,
            full_neg_txt=args.full_neg_test,
            batch_size=args.bs,
            num_workers=args.data_workers_for_test,
            worker_gpus=test_worker_gpus,
            roi_feat_resolution=cfg.roi_feat_resolution,
            ob_window_len=cfg.ob_window_len,
            interval=cfg.interval,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            augment=False,
            resample_negs_per_epoch=False,
            for_test=True)

        test_dataloader.save_to_txt(
            os.path.join(args.save, 'eval_data.txt'), dt=200)
    else:
        train_dataloader = XiaoduHiDecordLoader(
            attention_ctrl.feed_list, [place],
            args.yolov4_model_dir, args.decord_ds_pkl,
            decord_readers=args.decord_readers,
            yolov4_detectors=args.decord_detectors,
            post_workers=args.decord_post_workers,
            batch_size=args.bs,
            detector_gpus=train_worker_gpus,
            roi_feat_resolution=cfg.roi_feat_resolution,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            for_test=False)
        test_dataloader = XiaoduHiDecordLoader(
            attention_ctrl.feed_list, [place],
            args.yolov4_model_dir, args.decord_ds_pkl,
            decord_readers=args.decord_readers,
            yolov4_detectors=args.decord_detectors,
            post_workers=args.decord_post_workers,
            batch_size=args.bs,
            detector_gpus=test_worker_gpus,
            roi_feat_resolution=cfg.roi_feat_resolution,
            tokens_per_frame=cfg.tokens_per_frame,
            visual_token_dim=cfg.visual_token_dim,
            for_test=True)

    train_dataloader.start_workers()
    test_dataloader.start_workers()

    tb_writer = None
    try:
        train_log = os.path.join(args.save, 'loss.csv')
        eval_log = os.path.join(args.save, 'eval.txt')
        with open(os.path.join(args.save, 'args.txt'), 'w') as f:
            f.write(str(args))

        tb_writer = SummaryWriter(
            logdir=os.path.join(args.save, 'logdir'),
            purge_step=None if _update_step == 0 else _update_step)

        worker_index = None if not args.distributed_training \
            else fleet.worker_index()

        # if worker_index == 0:
        #     eval_model(exe, test_program, preds, attention_ctrl.act_loss,
        #                test_dataloader, -1, log_file=eval_log,
        #                tb_writer=tb_writer, worker_index=worker_index)

        # Track the last trained epoch explicitly: the loop variable is
        # undefined when the range is empty (resume with nothing left).
        last_epoch = None
        last_epoch_evaluated = False
        for epoch_id in range(start_epoch, args.epochs):
            print('--------------- Epoch %d ---------------' % epoch_id)
            train_epoch(exe, train_program, attention_ctrl, train_dataloader,
                        log_file=train_log, tb_writer=tb_writer,
                        worker_index=worker_index)

            save_dir = os.path.join(args.save, 'epoch_{}'.format(epoch_id))
            shutil.rmtree(save_dir, ignore_errors=True)
            os.mkdir(save_dir)
            fluid.io.save_params(exe, save_dir, main_program=train_program)

            evaluated = epoch_id > 0 and \
                epoch_id % args.run_eval_after_epochs == 0
            if evaluated:
                # worker_index is forwarded so that distributed runs log
                # eval scalars the same way as the final evaluation below.
                eval_model(exe, test_program, preds, attention_ctrl.act_loss,
                           test_dataloader, epoch_id, log_file=eval_log,
                           tb_writer=tb_writer, worker_index=worker_index)

            tb_state = os.path.join(save_dir, 'tb_state.txt')
            with open(tb_state, 'w') as f:
                f.write('{} {}'.format(_update_step, _eval_step))

            last_epoch = epoch_id
            last_epoch_evaluated = evaluated

        if last_epoch is None:
            print('No epoch to train: start epoch {} is not below '
                  '--epochs {}.'.format(start_epoch, args.epochs))
        elif not last_epoch_evaluated:
            # Make sure the final weights are always evaluated, including
            # the case where the only trained epoch is epoch 0.
            eval_model(exe, test_program, preds, attention_ctrl.act_loss,
                       test_dataloader, last_epoch, log_file=eval_log,
                       tb_writer=tb_writer, worker_index=worker_index)
    finally:
        # Release data workers and flush TensorBoard events even when
        # training or evaluation raised.
        try:
            train_dataloader.stop_workers()
            test_dataloader.stop_workers()
        finally:
            if tb_writer is not None:
                tb_writer.close()


if __name__ == '__main__':
    # When invoked without any command-line argument, act as if '-h' had
    # been passed so that argparse prints the usage text and exits.
    if len(sys.argv) == 1:
        sys.argv.append('-h')
    args = parse_args()
    # Select the 'spawn' start method for multiprocessing before main()
    # creates the dataloaders and starts their workers.
    mp.set_start_method('spawn')
    main(args)
